import torch
import numpy as np
import os
from tqdm import tqdm
import sys
import json
import matplotlib.pyplot as plt

CODE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(CODE_DIR)

from subprocess import call
from tree import PartTree, save_tree, load_tree
from graph import get_tree
from openclip_utils import OpenCLIPNetwork, OpenCLIPNetworkConfig
from eval_clip import evaluate_iou
from tree_query_qwen import query_tree, load_model_and_processor

def read_selected_directories(file_path):
    """Return the whitespace-stripped, non-empty lines of ``file_path``.

    Each line names one shape id. Blank lines are skipped so that a stray empty
    line in the list file does not turn into an empty shape id (which would
    otherwise yield paths like ``vis/partnet//tree``).

    Args:
        file_path: path to a UTF-8 text file with one entry per line.

    Returns:
        list[str]: the stripped entries, in file order.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return [name for name in (line.strip() for line in f) if name]

# PartNet data directory; each shape is read from os.path.join(PARTNET_DIR, <shape id>).
PARTNET_DIR = '/home/codeysun/git/data/PartNet/data_v0/'
# One shape id per line; this relative path is resolved against the current working directory.
shape_list_txt = 'selected_partnet_data.txt'
shape_list = set(read_selected_directories(shape_list_txt))
# Extra shape ids processed in addition to those listed in shape_list_txt.
shape_list_extra = [10558, 8677, 555, 14102, 11526, 7289, 23894, 10101, 1135, 1128, 10806]

shape_list.update(str(shape_id) for shape_id in shape_list_extra)

# Sorted so the processing order (and therefore the float summation order of
# the per-shape mIoU in eval()) is reproducible: iterating a set of str depends
# on PYTHONHASHSEED and changes between interpreter runs.
SHAPE_LIST = sorted(shape_list)

# Thresholds passed to evaluate_iou; one set of metrics is produced per value.
THRESHOLDS = [0.2, 0.3, 0.4, 0.45, 0.5, 0.55, 0.6, 0.7, 0.75, 0.8]

def run_script(cmd):
    """Run ``cmd`` through the shell and raise if it exits with a non-zero status.

    Args:
        cmd: shell command line. It is executed with ``shell=True``, so only pass
            trusted, internally built commands.

    Raises:
        RuntimeError: if the command's exit status is non-zero. The message
            includes the command and its exit status. RuntimeError subclasses
            Exception, so existing ``except Exception`` handlers still catch it.
    """
    ret = call(cmd, shell=True)
    if ret != 0:
        raise RuntimeError(f"Failed to run {cmd} (exit status {ret})")
    

def build_trees():
    """Build the part tree of every shape in ``SHAPE_LIST`` with ``partnet_tree_vis.py``.

    Each shape is processed by its own subprocess. If one shape fails, the error
    is printed and the loop continues with the next shape, instead of abandoning
    every remaining shape as a single outer ``try`` would.

    Returns:
        list[str]: ids of the shapes whose build command failed (empty if all succeeded).
    """
    failed = []
    for shape in tqdm(SHAPE_LIST):
        shape = str(shape)

        cmd = f'python partnet_tree_vis.py --partnet_dir {PARTNET_DIR} --output {shape}'
        try:
            run_script(cmd)
        except Exception as e:  # best-effort batch job: report and keep going
            print(e)
            failed.append(shape)

    if failed:
        print(f"build_trees: {len(failed)}/{len(SHAPE_LIST)} shapes failed: {', '.join(failed)}")
    return failed

def label_trees():
    """Run the Qwen2.5-VL model over every shape's part tree and save the result.

    For each shape id, loads ``vis/partnet/<id>/tree/tree.pkl`` and runs
    ``query_tree`` on it (presumably filling in node captions; confirm in
    tree_query_qwen). It then writes the ``tree_with_labels`` rendering and the
    ``tree_labeled.pkl`` pickle into that same directory.
    """
    model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
    model, processor = load_model_and_processor(model_name)
    for shape_id in map(str, tqdm(SHAPE_LIST)):
        tree_dir = f'vis/partnet/{shape_id}/tree'

        labeled = load_tree(os.path.join(tree_dir, 'tree.pkl'))
        query_tree(labeled, model, processor)
        labeled.render_tree(os.path.join(tree_dir, 'tree_with_labels'))
        save_tree(labeled, os.path.join(tree_dir, 'tree_labeled.pkl'))

def _safe_ratio(num, den):
    """Return ``num / den`` as a Python float, or NaN when ``den`` is zero."""
    return float(num) / float(den) if den else float('nan')


def _read_iou(iou_path):
    """Return the number after the last ':' on the first line of ``iou_path``."""
    with open(iou_path, 'r') as f:
        line = f.readline().strip()
    return float(line.split(':')[-1].strip())


def _run_method(model, tree_labeled_path, shape_dir, thres, use_clip):
    """Evaluate one shape with one method and render its activation tree.

    A fresh copy of the labeled tree is loaded on every call, so edits made to
    one copy (e.g. the ground-truth caption overwrite in ``eval``) do not leak
    into the other method's run. ``use_clip`` is forwarded to ``evaluate_iou``
    and also picks the ``_CLIP`` suffix of the rendered file.

    Returns:
        The evaluated tree object.
    """
    torch.cuda.empty_cache()
    model.set_positives([""])
    tree = load_tree(tree_labeled_path)
    evaluate_iou(tree, model, 'partnet', shape_dir, use_clip, thres)

    suffix = '_CLIP' if use_clip else ''
    tree.render_tree(os.path.join(os.path.dirname(tree_labeled_path), f'tree_activations{suffix}'))
    return tree


def _read_results(output_dir, thres, use_clip):
    """Load the ``(iou, confmat)`` pair for one method at one threshold.

    Reads ``iou[_CLIP]_<thres>.txt`` and ``confmat[_CLIP]_<thres>.txt`` from
    ``output_dir``. These are presumably written by ``evaluate_iou``; confirm in
    eval_clip.
    """
    suffix = '_CLIP' if use_clip else ''
    iou = _read_iou(os.path.join(output_dir, f'iou{suffix}_{thres:.2f}.txt'))
    confmat = np.loadtxt(os.path.join(output_dir, f'confmat{suffix}_{thres:.2f}.txt'), dtype=int)
    return iou, confmat


def _append_metrics(out, miou_sum, confmat, num_shapes):
    """Append mIoU, accuracy, precision and recall for one threshold to ``out``.

    The formulas treat ``confmat[0, 0]`` as the positive/positive cell, with
    rows as one label axis and columns as the other. NOTE(review): confirm this
    against the layout evaluate_iou writes. A zero denominator yields NaN.
    """
    out["miou"].append(_safe_ratio(miou_sum, num_shapes))
    out["accuracy"].append(_safe_ratio(confmat[0, 0] + confmat[1, 1], np.sum(confmat)))
    out["precision"].append(_safe_ratio(confmat[0, 0], confmat[0, 0] + confmat[0, 1]))
    out["recall"].append(_safe_ratio(confmat[0, 0], confmat[0, 0] + confmat[1, 0]))


def eval():
    """Evaluate our tree labels and the CLIP baseline on every shape at every threshold.

    For each threshold in ``THRESHOLDS`` and each shape in ``SHAPE_LIST``:
    evaluate and render our method, render the ground-truth tree, then evaluate
    and render the CLIP baseline. The per-shape IoU values and confusion
    matrices are accumulated and summarized into one entry per threshold.

    NOTE: the name shadows the builtin ``eval``; it is kept for compatibility.

    Returns:
        (stats, stats_clip): dicts mapping "miou"/"accuracy"/"precision"/"recall"
        to lists with one value per entry of ``THRESHOLDS``. Ratios with a zero
        denominator (e.g. an empty ``SHAPE_LIST``) are reported as NaN instead
        of raising or emitting numpy warnings.
    """
    stats = {"miou": [], "accuracy": [], "precision": [], "recall": []}
    stats_clip = {"miou": [], "accuracy": [], "precision": [], "recall": []}
    model = OpenCLIPNetwork(OpenCLIPNetworkConfig)
    with torch.no_grad():
        for thres in THRESHOLDS:
            miou = 0
            miou_clip = 0

            confmat_total = np.zeros((2, 2), dtype=int)
            confmat_total_clip = np.zeros((2, 2), dtype=int)
            for shape in tqdm(SHAPE_LIST):
                shape = str(shape)
                output_dir = f'vis/partnet/{shape}/tree'
                shape_dir = os.path.join(PARTNET_DIR, shape)
                tree_labeled_path = os.path.join(output_dir, 'tree_labeled.pkl')

                # Our method.
                tree = _run_method(model, tree_labeled_path, shape_dir, thres, use_clip=False)

                # Render the ground-truth tree; this overwrites captions on this copy only.
                for node in tree.get_nodes():
                    node.caption = node.gt_caption
                tree.render_tree(os.path.join(output_dir, "tree_gt"))

                iou, confmat = _read_results(output_dir, thres, use_clip=False)
                miou += iou
                confmat_total = confmat_total + confmat

                # CLIP baseline (reloads the labeled tree from disk).
                _run_method(model, tree_labeled_path, shape_dir, thres, use_clip=True)
                iou, confmat = _read_results(output_dir, thres, use_clip=True)
                miou_clip += iou
                confmat_total_clip = confmat_total_clip + confmat

            # Summarize this threshold.
            _append_metrics(stats, miou, confmat_total, len(SHAPE_LIST))
            _append_metrics(stats_clip, miou_clip, confmat_total_clip, len(SHAPE_LIST))

    return stats, stats_clip


def save_data(stats, stats_clip):
    """Dump the statistics to ``stats.txt`` (JSON) and plot them to ``stats_graph.png``.

    Args:
        stats: mapping of metric name to a list of values, one per entry of
            ``THRESHOLDS`` (our method).
        stats_clip: same layout for the CLIP baseline. It must contain every key
            of ``stats``.
    """
    import math

    data_to_save = {
        "thresholds": THRESHOLDS,
        "stats": stats,
        "stats_clip": stats_clip
    }
    with open('stats.txt', 'w', encoding='utf-8') as f:
        json.dump(data_to_save, f, indent=4)

    # One subplot per metric on a two-column grid. This gives 2x2 (15x15 in)
    # for the usual four metrics, but no longer breaks on other metric counts.
    metrics = list(stats.keys())
    ncols = 2
    nrows = max(1, math.ceil(len(metrics) / ncols))
    fig, axs = plt.subplots(nrows, ncols, figsize=(15, 7.5 * nrows), squeeze=False)
    fig.suptitle('Comparison of Statistics across Thresholds', fontsize=16)

    axs = axs.flatten()

    for ax, stat in zip(axs, metrics):
        ax.plot(THRESHOLDS, stats[stat], 'b-o', label='Ours')
        ax.plot(THRESHOLDS, stats_clip[stat], 'r-o', label='CLIP baseline')

        ax.set_xlabel('Threshold')
        ax.set_ylabel(stat.capitalize())
        ax.set_title(f'{stat.capitalize()} vs Threshold')
        ax.legend()
        ax.grid(True)

        # Set x-axis ticks to match THRESHOLDS
        ax.set_xticks(THRESHOLDS)
        ax.set_xticklabels([f'{t:.2f}' for t in THRESHOLDS])

    # Hide grid cells left empty when the metric count is odd.
    for ax in axs[len(metrics):]:
        ax.set_visible(False)

    fig.tight_layout()

    out_path = "stats_graph.png"
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig)

    print(f"Figure saved as {out_path}")


if __name__ == "__main__":
    import argparse

    # The optional pipeline stages used to be toggled by commenting code in and
    # out; they are now opt-in flags. With no flags, only evaluation runs, as before.
    parser = argparse.ArgumentParser(
        description="Build, label and evaluate PartNet part trees.")
    parser.add_argument("--build", action="store_true",
                        help="create PartNet trees with partnet_tree_vis.py before evaluating")
    parser.add_argument("--label", action="store_true",
                        help="label trees with the VLM before evaluating")
    args = parser.parse_args()

    if args.build:
        build_trees()

    if args.label:
        label_trees()

    stats, stats_clip = eval()

    save_data(stats, stats_clip)
